{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d2d49e3c",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a9214247",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import warnings\n",
    "from tqdm import tqdm\n",
    "from torch.autograd import Variable\n",
    "from sklearn.metrics import mean_absolute_error"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9eec470",
   "metadata": {},
   "source": [
    "### Auglichem imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f159194c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from auglichem.crystal import (PerturbStructureTransformation,\n",
    "                               RotationTransformation,\n",
    "                               SwapAxesTransformation,\n",
    "                               TranslateSitesTransformation,\n",
    "                               SupercellTransformation,\n",
    ")\n",
    "from auglichem.crystal.data import CrystalDatasetWrapper\n",
    "from auglichem.crystal.models import GINet, GCN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d43e54e",
   "metadata": {},
   "source": [
    "### Set up dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6068e947",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data to: ./data_download/lanths...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "99it [00:00, 253.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting zipfile...\n",
      "Removing zipfile...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|▏                                                                                                  | 8/3332 [00:00<02:48, 19.76it/s]/home/mlai/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/pymatgen/io/cif.py:1165: UserWarning: Issues encountered while parsing CIF: Some fractional co-ordinates rounded to ideal values to avoid issues with finite precision.\n",
      "  warnings.warn(\"Issues encountered while parsing CIF: %s\" % \"\\n\".join(self.warnings))\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████| 3332/3332 [01:20<00:00, 41.35it/s]\n",
      "/home/mlai/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
      "  warnings.warn(out)\n"
     ]
    }
   ],
   "source": [
    "# Create transformation\n",
    "transforms = [\n",
    "        PerturbStructureTransformation(distance=0.1, min_distance=0.01),\n",
    "        RotationTransformation(axis=[0,0,1], angle=90),\n",
    "        SwapAxesTransformation(),\n",
    "        TranslateSitesTransformation(indices_to_move=[0], translation_vector=[1,0,0],\n",
    "                                     vector_in_frac_coords=True),\n",
    "        SupercellTransformation(scaling_matrix=[[1,0,0],[0,1,0],[0,0,1]]),\n",
    "]\n",
    "\n",
    "# Initialize dataset object\n",
    "dataset = CrystalDatasetWrapper(\"lanthanides\", batch_size=128,\n",
    "                                valid_size=0.1, test_size=0.1)\n",
    "\n",
    "# Get train/valid/test splits as loaders\n",
    "train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89d4a737",
   "metadata": {},
   "source": [
    "### Initialize model with task from data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b8bc773",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get model\n",
    "model = GINet() # Note: GCN and GINet are interchangeable in use cases\n",
    "\n",
    "# Uncomment the following line to use cuda\n",
    "#model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dabddda",
   "metadata": {},
   "source": [
    "### Initialize traning loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ff61bb9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f45f87f",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "827c88e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "56it [01:09,  1.24s/it]\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "Invalid cif file with no structures!",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_19004/3864175603.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0mwarnings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msimplefilter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"ignore\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m         \u001b[0;32mfor\u001b[0m \u001b[0mbn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m             \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/tqdm/std.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1178\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1179\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1180\u001b[0;31m             \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1181\u001b[0m                 \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1182\u001b[0m                 \u001b[0;31m# Update and possibly print the progressbar.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    519\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    520\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    522\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    523\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    559\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    560\u001b[0m         \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 561\u001b[0;31m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    562\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    563\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     47\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     47\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/auglichem/crystal/data/_crystal_dataset.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m    501\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_crystal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    502\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 503\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_knn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    504\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    505\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/auglichem/crystal/data/_crystal_dataset.py\u001b[0m in \u001b[0;36m_getitem_knn\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m    467\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    468\u001b[0m         \u001b[0;31m# read cif using pymatgen\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 469\u001b[0;31m         \u001b[0maug_crys\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mStructure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maugment_cryst_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    470\u001b[0m         \u001b[0mpos\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0maug_crys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrac_coords\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    471\u001b[0m         \u001b[0matom_indices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maug_crys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0matomic_numbers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/pymatgen/core/structure.py\u001b[0m in \u001b[0;36mfrom_file\u001b[0;34m(cls, filename, primitive, sort, merge_tol)\u001b[0m\n\u001b[1;32m   2468\u001b[0m             \u001b[0mcontents\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2469\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mfnmatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"*.cif*\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfnmatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"*.mcif*\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2470\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_str\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontents\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfmt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"cif\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msort\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msort\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmerge_tol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmerge_tol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2471\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mfnmatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"*POSCAR*\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfnmatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"*CONTCAR*\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfnmatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"*.vasp\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2472\u001b[0m             s = cls.from_str(\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/pymatgen/core/structure.py\u001b[0m in \u001b[0;36mfrom_str\u001b[0;34m(cls, input_string, fmt, primitive, sort, merge_tol)\u001b[0m\n\u001b[1;32m   2389\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2390\u001b[0m             \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCifParser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_string\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2391\u001b[0;31m             \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_structures\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimitive\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprimitive\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2392\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0mfmt_low\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"poscar\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2393\u001b[0m             \u001b[0;32mfrom\u001b[0m \u001b[0mpymatgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvasp\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPoscar\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/dev_auglichem/lib/python3.8/site-packages/pymatgen/io/cif.py\u001b[0m in \u001b[0;36mget_structures\u001b[0;34m(self, primitive, symmetrized)\u001b[0m\n\u001b[1;32m   1165\u001b[0m             \u001b[0mwarnings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Issues encountered while parsing CIF: %s\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m\"\\n\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarnings\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1166\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstructures\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1167\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Invalid cif file with no structures!\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1168\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mstructures\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1169\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Invalid cif file with no structures!"
     ]
    }
   ],
   "source": [
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    for epoch in range(1):\n",
    "        for bn, data in tqdm(enumerate(train_loader)):        \n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Comment out the following line and uncomment the line after for cuda\n",
    "            pred = model(data)\n",
    "            #pred = model(data.cuda())\n",
    "            \n",
    "            loss = criterion(pred, data.y)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "558838e0",
   "metadata": {},
   "source": [
    "### Test the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a12118ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_loader, validation=False):\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.simplefilter(\"ignore\")\n",
    "        with torch.no_grad():\n",
    "            model.eval()\n",
    "            preds = torch.Tensor([])\n",
    "            targets = torch.Tensor([])\n",
    "            for data in test_loader:\n",
    "                pred = model(data)\n",
    "                #pred = model(data.cuda())\n",
    "                preds = torch.cat((preds, pred.cpu()))\n",
    "                targets = torch.cat((targets, data.y.cpu()))\n",
    "\n",
    "            mae = mean_absolute_error(preds, targets)   \n",
    "        \n",
    "        set_str = \"VALIDATION\" if(validation) else \"TEST\"\n",
    "        print(\"{0} MAE: {1:.3f}\".format(set_str, mae))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d06ffa",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(model, valid_loader, validation=True)\n",
    "evaluate(model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c272546e",
   "metadata": {},
   "source": [
    "### Model saving/loading example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3016e288",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model\n",
    "torch.save(model.state_dict(), \"./example_ginet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a53ce66f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate new model and evaluate\n",
    "model = GINet()\n",
    "\n",
    "evaluate(model, valid_loader, validation=True)\n",
    "evaluate(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dea9d36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load saved model and evaluate\n",
    "model.load_state_dict(torch.load(\"./example_ginet\"))\n",
    "evaluate(model, valid_loader, validation=True)\n",
    "evaluate(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79f9b873",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
